/* 
 * File:   ImageWriter.cpp
 * Author: yuri
 * 
 * Created on June 6, 2011, 2:59 PM
 */

#include "include/ImageWriter.hpp"
#include <iostream>
#include <cstdlib>
#include <fstream>
#include <string>
#include <vector>

#include "include/messages.h"
#include "include/io.hpp"
#include "include/Node.hpp"
#include "lzma/LzmaLib.h"
#include "fileio/fileio.hpp"
#include "ImageStatistics.hpp"


using namespace std;

/* Writer lifecycle: FILE_OPEN after the constructor, HEADER_WRITTEN after writeHeader(),
 * FILE_SAVED after save(). INVALID and FILE_CLOSE are never assigned in this file.
 * NOTE(review): `state` is defined here at file scope. Unless ImageWriter.hpp also declares a
 * member named `state` (which would shadow this one inside member functions), all ImageWriter
 * instances share this single variable -- confirm against the header. */
enum IW_STATE {  INVALID = -1, FILE_OPEN, HEADER_WRITTEN, FILE_SAVED, FILE_CLOSE } state;

/* Thresholds used by the tree filters below (removeSmallPaths / removeSmallObjects).
 * They are non-const globals; presumably tuned from elsewhere -- TODO confirm. */
int MIN_PATH_LENGTH = 2;   // a fork's side branch is pruned if its numRChildren() is below this
int MIN_SUM_RADIUS  = 10;  // a branch/object is pruned if its importance() is below this
int MIN_OBJECT_SIZE = 5;   // a top-level object is dropped if its numRChildren() is below this

/* Open `fname` for binary output. Without a writable destination nothing can be produced,
 * so an open failure is reported and terminates the program. */
ImageWriter::ImageWriter(const char *fname) {
    of.open(fname, ios_base::out | ios::binary);

    if (!of.good()) {
        PRINT(MSG_ERROR, "Could not open file for output\n");
        exit(1);
    }

    state = FILE_OPEN;
}

/**
 * Compress the output using the LZMA library, and write the output to the file.
 * @return 1 when result was OK, otherwise 0.
 */
int ImageWriter::save(){
    SizeT propsSize = LZMA_PROPS_SIZE;
    SizeT destLen = 2* ofBuffer.str().size() + 512;
    unsigned char *out = new unsigned char[destLen];

    PRINT(MSG_NORMAL, "Compressing output using LZMA...\n");
    int res = LzmaCompress(&out[propsSize], &destLen, (const unsigned char*) ofBuffer.str().c_str(), ofBuffer.str().size(), &out[0], &propsSize, 9, 0, -1, -1, -1, -1, 1);
    PRINT(MSG_NORMAL, "Compression Done. Original size: %.5fKB, compressed size: %.5fKB.\n", ofBuffer.str().size() / 1024.0, (destLen + propsSize) / 1024.0);
    
    
    /* DEBUG -- Write statistics */
    ofstream *f=IS_get_statistics_ofstream();
    (*f) << "FILE_SIZE=" << (destLen + propsSize) << ";" << endl;
    /* END DEBUG */
    
    int origlen = ofBuffer.str().size();

    of.write((const char *) &origlen, sizeof(int));
    of.write((const char *) &out[0], destLen + LZMA_PROPS_SIZE);

    state = FILE_SAVED;

    of.close();
    
    delete [] out;
    return res == SZ_OK;
}

/*
 * Append the file header to the output buffer: WRITER_FILE_VERSION_NUMBER, width and
 * height, two bytes each (width/height stored as uint16_t). Terminates the program if
 * either dimension does not fit into 16 bits.
 */
void ImageWriter::writeHeader(unsigned int width, unsigned int height) {
    /* Largest value a uint16_t can hold. The old bound (2 << 16 == 131072) let sizes
       65536..131072 through, and they were then silently truncated. */
    const unsigned int MAX_DIMENSION = 0xFFFF;
    if (width > MAX_DIMENSION || height > MAX_DIMENSION) {
        PRINT(MSG_ERROR, "Due to compression the width and height of the file are limited to 65535px.\n");
        exit(1);
    }

    uint16_t out16;
    ofBuffer.write((char *) &WRITER_FILE_VERSION_NUMBER, sizeof(uint16_t));
    out16 = (uint16_t) width;
    ofBuffer.write((char *) &out16, sizeof (uint16_t));
    out16 = (uint16_t) height;
    ofBuffer.write((char *) &out16, sizeof (uint16_t));

    state = HEADER_WRITTEN;
}

/* Flush the buffered output to disk if save() was never called, then close the file. */
ImageWriter::~ImageWriter() {
    const bool needsSave = (state != FILE_SAVED);
    if (needsSave)
        save();
    of.close();
}

/*
 * Append one chain-code byte describing the step from `prev` to its child `cur`.
 *  - High nibble: position of prev relative to cur, 0..7, numbered row-major over the 3x3
 *    window around cur with the centre skipped; 8 means "not 8-adjacent" (an error is
 *    printed but the byte is still written).
 *  - Low nibble: radius change (cur.third - prev.third) biased by +8; a change outside
 *    [-8, 7] would spill into the high nibble.
 * NOTE(review): the range check below warns outside [-2, 2] although the nibble holds
 * [-8, 7]; presumably the decoder/format expects [-2, 2] -- confirm with SIRFORMAT.
 */
void ImageWriter::writeChain(coord3D_t prev, coord3D_t cur){
    /* Direction code indexed by (dy + 1) * 3 + (dx + 1) with (dx, dy) = prev - cur; centre -> 8. */
    static const unsigned char DIRECTION_CODE[9] = { 0, 1, 2, 3, 8, 4, 5, 6, 7 };

    const int dx = prev.first  - cur.first;
    const int dy = prev.second - cur.second;

    unsigned char out8 = 8;
    if (dx >= -1 && dx <= 1 && dy >= -1 && dy <= 1)
        out8 = DIRECTION_CODE[(dy + 1) * 3 + (dx + 1)];

    if (out8 == 8)
        PRINT(MSG_ERROR, "Could not encode neighbours...\n");

    /* Check both bounds: the previous test `diff + 2 > 4` only caught diff > 2, so large
       radius decreases went unreported. */
    const int radiusDiff = cur.third - prev.third;
    if (radiusDiff < -2 || radiusDiff > 2) {
        PRINT(MSG_ERROR, "Error encoding radius difference (outside range [-2, 2]): %d -> %d\n", prev.third, cur.third);
    }

    out8 = (out8 << 4) + ((cur.third - prev.third) + 8);

    ofBuffer.write((const char *)&out8, sizeof(char));
}

/*
 * Serialise the subtree below `st` (st's own point has already been written by the caller).
 * Every step from a node to a child is emitted with writeChain(). When a leaf is reached:
 *  - rightMost == true  -> END_TAG (end of this object),
 *  - rightMost == false -> FORK_TAG followed by pLength as a uint16_t (the jump-back length).
 * At a fork, every child except the last starts a new side branch with pLength = 1 and
 * rightMost = false. The last child continues the current branch with pLength + 1, so the
 * eventual jump reaches back past this fork.
 *
 * The last child is followed with a loop rather than a tail call. Unbranched stretches are
 * one node per pixel, so the recursive form grew the call stack with path length.
 */
void ImageWriter::writePath(skel_tree_t *st, int pLength, bool rightMost){
    for (;;) {
        const int childCount = st->numChildren();

        /* Leaf: terminate this branch. */
        if (childCount == 0) {
            if (rightMost) {
                ofBuffer.write((const char*) &END_TAG, 1*sizeof(char));
            } else {
                if (pLength > 0xFFFF)
                    PRINT(MSG_ERROR, "Fork jump length %d does not fit in 16 bits.\n", pLength);
                uint16_t jump = (uint16_t) pLength;
                ofBuffer.write((const char*) &FORK_TAG, 1*sizeof(char));
                ofBuffer.write((const char*) &jump, sizeof(uint16_t));
            }
            return;
        }

        /* All non-rightmost children: independent side branches. */
        for (int i = 0; i < childCount - 1; ++i) {
            writeChain(st->getValue(), st->getChild(i)->getValue());
            writePath(st->getChild(i), 1, false);
        }

        /* Rightmost child continues this branch (one more point to jump back over). */
        skel_tree_t *lastChild = st->getChild(childCount - 1);
        writeChain(st->getValue(), lastChild->getValue());
        st = lastChild;
        ++pLength;
    }
}

/*
 * Prune side branches throughout the tree. At every fork (node with more than one child), a child
 * is removed when its subtree has fewer than MIN_PATH_LENGTH nodes (numRChildren) or its
 * importance() is below MIN_SUM_RADIUS. Pruning stops once a single child remains, so a fork is
 * never cut down to nothing. The procedure then continues into all surviving children.
 *
 * Unbranched stretches (exactly one child) are walked with a loop instead of recursion. Skeleton
 * paths are one node per pixel, so recursing there made stack depth grow with path length.
 */
void removeSmallPaths(skel_tree_t *st){
    for (;;) {
        if (st->numChildren() > 1) {
            for (int i = 0; i < st->numChildren() && st->numChildren() > 1; ++i) {
                if (st->getChild(i)->numRChildren() < MIN_PATH_LENGTH || st->getChild(i)->importance() < MIN_SUM_RADIUS) {
                    st->removeRChild(i);
                    --i;    /* the next child has shifted into slot i */
                }
            }
        }

        if (st->numChildren() != 1)
            break;
        st = st->getChild(0);
    }

    for (int i = 0; i < st->numChildren(); ++i) {
        removeSmallPaths(st->getChild(i));
    }
}


/*
 * Drop whole objects (direct children of `st`) whose subtree has fewer than MIN_OBJECT_SIZE
 * nodes (numRChildren) or whose importance() is below MIN_SUM_RADIUS.
 */
void removeSmallObjects(skel_tree_t *st){
    int idx = 0;
    while (idx < st->numChildren()) {
        skel_tree_t *object = st->getChild(idx);
        if (object->numRChildren() < MIN_OBJECT_SIZE || object->importance() < MIN_SUM_RADIUS)
            st->removeRChild(idx);   /* the following object moves into idx; re-examine it */
        else
            ++idx;
    }
}

/* Preprocessing steps that simplify the skeleton can introduce "new" points that lie outside
 * the actual object (e.g. through dilation). Such points have a radius (third component) of 0.
 * When they sit at the end of a skeleton path they can safely be removed. Each child's subtree
 * is cleaned before the child itself is tested, so a whole tail of zero-radius points is peeled
 * off, not just the last one. */
void removeInvisibleEndpoints(skel_tree_t *st){
    int idx = 0;
    while (idx < st->numChildren()) {
        skel_tree_t *child = st->getChild(idx);
        removeInvisibleEndpoints(child);

        const bool isLeaf = (child->numRChildren() == 0);
        if (isLeaf && child->getValue().third == 0)
            st->removeRChild(idx);   /* next sibling shifts into idx */
        else
            ++idx;
    }
}

/*
 * Simplify a traced layer in place before it is written. The order matters: each step removes
 * nodes, which changes the node counts (numRChildren) the later steps test against.
 */
void filterTree(skel_tree_t *st){
    removeInvisibleEndpoints(st);   /* zero-radius path ends first        */
    removeSmallPaths(st);           /* then short / weak branches at forks */
    removeSmallObjects(st);         /* finally whole objects that are tiny */
}

/* See file SIRFORMAT for an explanation of how the file is stored.
 *
 * Traces the skeleton pixels of one layer into a tree (tracePath() zeroes the pixels it
 * visits, so `skel` is consumed), filters it, logs pre/post-filter statistics, and appends
 * the layer to the output buffer as: intensity (1 byte), number of objects (uint16_t), then
 * for each object its start point (x, y, radius as 3 x uint16_t) followed by the codes
 * produced by writePath(). Layers left empty after filtering are not written at all.
 */
void ImageWriter::writeLayer(FIELD<float> *skel, FIELD<float> *dt, unsigned char intensity) {
    uint16_t out16[3];
    skel_tree_t *st = traceLayer(skel, dt);

    /* DEBUG -- Write statistics */
    ofstream *f=IS_get_statistics_ofstream();
    (*f) << "number_of_objects_pretreefilter(" << (int)intensity+1 << ")=" << st->numChildren() << ";" << endl;
    (*f) << "number_of_on_pixels_pretreefilter(" << (int)intensity+1 << ")=" << st->numRChildren() << ";" << endl;
    /* END DEBUG */

    filterTree(st);

    /* DEBUG -- Write statistics */
    (*f) << "number_of_objects_posttreefilter(" << (int)intensity+1 << ")=" << st->numChildren() << ";" << endl;
    (*f) << "number_of_on_pixels_posttreefilter(" << (int)intensity+1 << ")=" << st->numRChildren() << ";" << endl;
    /* END DEBUG */

    if (state != HEADER_WRITTEN) {
        PRINT(MSG_ERROR, "Header has not been written yet -- Aborting.\n");
        delete st;   /* the traced tree used to leak on this path */
        return;
    }

    /* Remove overhead for empty layers. */
    if (st->numChildren() == 0) { delete st; return; }

    /* The object count is stored in 16 bits; larger counts would be silently truncated. */
    if (st->numChildren() > 0xFFFF)
        PRINT(MSG_ERROR, "Layer %d has %d objects; only 65535 can be encoded.\n", (int)intensity, (int)st->numChildren());

    /* Write intensity and number of paths. */
    out16[0] = (uint16_t) st->numChildren();
    ofBuffer.write((char *)&intensity, sizeof(unsigned char));
    ofBuffer.write((char *)&out16[0], sizeof(uint16_t));

    /* Top layer are disjunct paths. Treat them as separate objects */
    for (int child = 0; child < st->numChildren(); ++child) {
        skel_tree_t *curnode = (*st)[child];
        coord3D_t curpoint = curnode->getValue();
        out16[0] = curpoint.first;
        out16[1] = curpoint.second;
        out16[2] = curpoint.third;
        ofBuffer.write((char *) out16, 3*sizeof(uint16_t));
        writePath(curnode, 1, true);
    }

    delete st;
}

/*
 * Collect every skeleton pixel of one layer into a tree under a synthetic root valued (-1,-1,-1).
 * The image is scanned row by row. Each pixel still > 0 starts a new child traced by tracePath(),
 * which zeroes the pixels it visits, so later scan positions only reach pixels not yet traced.
 * The caller owns (and must delete) the returned tree.
 */
skel_tree_t* ImageWriter::traceLayer(FIELD<float> *skel, FIELD<float> *dt){
    coord3D_t rootValue(-1, -1, -1);
    skel_tree_t *root = new skel_tree_t(rootValue);

    for (int row = 0; row < skel->dimY(); ++row) {
        for (int col = 0; col < skel->dimX(); ++col) {
            if (!(skel->fvalue(col, row) > 0))
                continue;
            root->addChild(tracePath(col, row, skel, dt));
        }
    }

    return root;
}

skel_tree_t *ImageWriter::tracePath(int x, int y, FIELD<float> *skel, FIELD<float> *dt){
    coord2D_t n;
    coord2D_list_t *neigh;
    skel_tree_t *path;
    if(skel->fvalue(x,y)==0){
        PRINT(MSG_ERROR, "Reached invalid point.\n");
        return NULL;
    }

    /* Create new node and add to root */
    path = new Node<coord3D_t>(coord3D_t(x,y,dt->fvalue(x,y)));
    skel->set(x,y,0);
    
    neigh = neighbours(x,y,skel);
    /* Add children */
    while(neigh->size() > 0){
        n = *(neigh->begin());
        path->addChild(tracePath(n.first, n.second, skel, dt));
        delete neigh;
        neigh = neighbours(x,y,skel);
    }

    delete neigh;
    return path;
}

/*
 * Return a newly allocated list (caller deletes it) of the 8-neighbours of (x, y) whose value in
 * `skel` is > 0. Neighbours are reported row by row (top to bottom, left to right). Offsets that
 * would step past the image border are skipped.
 */
coord2D_list_t * ImageWriter::neighbours(int x, int y, FIELD<float> *skel){
    coord2D_list_t *found = new coord2D_list_t();

    for (int dy = -1; dy <= 1; ++dy) {
        if (dy < 0 && y <= 0) continue;
        if (dy > 0 && y >= skel->dimY() - 1) continue;

        for (int dx = -1; dx <= 1; ++dx) {
            if (dx == 0 && dy == 0) continue;            /* the pixel itself */
            if (dx < 0 && x <= 0) continue;
            if (dx > 0 && x >= skel->dimX() - 1) continue;

            if (skel->value(x + dx, y + dy) > 0)
                found->push_back(coord2D_t(x + dx, y + dy));
        }
    }

    return found;
}